{-
Copyright (c) 2018-2021 VMware, Inc.
SPDX-License-Identifier: MIT

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
-}

{-# LANGUAGE DeriveGeneric, RecordWildCards, LambdaCase, FlexibleContexts, OverloadedStrings, QuasiQuotes, ImplicitParams, TupleSections #-}

{- |
Module     : OVSDB.Compile
Description: Compile 'OVSDB schema' to DDlog.
-}

module Language.DifferentialDatalog.OVSDB.Compile (
    Config(..),
    defaultConfig,
    OutputRelConfig,
    compileSchema,
    compileSchemaFile) where

import Prelude hiding((<>), readFile, writeFile)
import Text.PrettyPrint
import Text.RawString.QQ
import Data.Word
import Data.Maybe
import Data.Char
import Data.List
import Data.Aeson (FromJSON, ToJSON)
import GHC.Generics (Generic)

import Language.DifferentialDatalog.OVSDB.Parse
import Language.DifferentialDatalog.Parse
import Language.DifferentialDatalog.Util
import Language.DifferentialDatalog.Name
import Language.DifferentialDatalog.Pos
import Language.DifferentialDatalog.PP
import Language.DifferentialDatalog.Error
import Control.Monad.Except

{-
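The compiler generates 5 kinds of DDlog tables for each OVSDB output table:
Input, Output, Delta+, Delta-, and DeltaUpdate.  OVSDB input tables only get an
Input relation, and output-only tables only get a single output relation that
DDlog computes directly, without reading the table's current OVSDB state.

Input: input tables that have the same schema as the corresponding OVSDB tables.

Output: these are the tables computed by the DDlog program.  They have the same
    schema as the corresponding OVSDB tables, except that read-only columns are
    omitted.

Delta+: new records to be inserted into OVSDB.
Delta-: records to be deleted from OVSDB.  Only generated for root tables, since
    OVSDB deletes records of non-root tables automatically.
DeltaUpdate: records to be modified in OVSDB.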

The last three tables are generated by comparing input and output tables.  If a record
with a given UUID exists in the input table, but not the output table, this record must
be deleted.  UUIDs found in the output table, but not the input table, must be created.
Records present in both tables with the same UUID but different values in other columns
must be updated.

Note that the user must only provide rules to compute Output relations based on Input
relations; the compiler generates rules to compute all other relations.

The following diagram illustrates dependencies among these relations.

      (user-defined rules)
Input--------------------> Output------> Delta+
                              |
                              |------------> Delta-
                              |
                              +----------> DeltaUpdate
-}


-- | Role an OVSDB table plays with respect to the DDlog program.
data OVSDBTableKind
    -- | DDlog reads, but does not modify the table.
    = OVSDBTableInput
    -- | DDlog generates updates to the table, but never reads its current state.
    | OVSDBTableOutputOnly
    -- | DDlog generates updates to the table as deltas between its current and
    -- new contents.
    | OVSDBTableOutput

-- | Kind of DDlog relation generated by this module.  A single OVSDB table may
-- give rise to several DDlog relations of different kinds.
data TableKind = TableInput         -- ^ @input@ relation mirroring the OVSDB table (all columns).
               | TableOutput        -- ^ Relation computed by user rules; holds only writable columns.
               | TableDirectOutput  -- ^ @output@ relation (or multiset) for an output-only OVSDB table.
               | TableDeltaPlus     -- ^ Records to be inserted into OVSDB.
               | TableDeltaMinus    -- ^ UUIDs of records to be deleted from OVSDB.
               | TableDeltaUpdate   -- ^ Records to be modified in OVSDB.
               deriving (Eq)

-- | Compiler configuration.
data Config = Config { ovsSchemaFile    :: FilePath          -- ^ OVSDB schema file.  NOTE(review): not read in this module; 'compileSchemaFile' takes the path as an argument.
                     , outputFile       :: Maybe FilePath    -- ^ NOTE(review): not used in this module; presumably where the caller writes the generated program.
                     , outputTables     :: [OutputRelConfig] -- ^ Tables DDlog updates as deltas against their current contents, with optional column restrictions.
                     , outputOnlyTables :: [String]          -- ^ Tables DDlog writes without ever reading their current state.
                     , multisetTables   :: [String]          -- ^ Output-only tables declared as @multiset@ instead of @relation@.
                     , internedTables   :: [String]          -- ^ Tables whose input relation holds @Intern<...>@ records.
                     , internStrings    :: Bool              -- ^ Use @istring@ rather than @string@ for OVSDB string values.
                     }
              deriving (Eq, Show, Generic)

-- JSON (de)serialization uses aeson's Generic-based default methods.
instance FromJSON Config
instance ToJSON   Config

-- | Baseline configuration: no schema or output file, no output or output-only
-- tables (so every table is treated as an input table), and no interning.
defaultConfig :: Config
defaultConfig = Config
    { internStrings    = False
    , internedTables   = []
    , multisetTables   = []
    , outputOnlyTables = []
    , outputTables     = []
    , outputFile       = Nothing
    , ovsSchemaFile    = ""
    }

-- | Output relation configuration:
-- (name, Left <set_of_readonly_columns> or Right <Set of rw columns>)
--
-- @Left ro@ marks the listed columns as read-only; @Right rw@ lists the writable
-- columns (@_uuid@ is always treated as writable in this case).  Every listed
-- column must exist in the table; 'mkTable' reports an error otherwise.
type OutputRelConfig = (String, Either [String] [String])

-- | Preamble emitted at the top of every generated program ('compileSchema'):
-- imports the DDlog @ovsdb@ module.
builtins :: String
builtins = [r|import ovsdb
|]

-- | Generate a legal DDlog relation name from a table name by upper-casing its
-- first character.
--
-- Fails with a descriptive 'error' (rather than an anonymous pattern-match
-- failure) if the table name is empty.
relationName :: Table -> Doc
relationName t =
    case name t of
         c : rest -> pp $ toUpper c : rest
         []       -> error "OVSDB.Compile.relationName: empty table name"

-- | Annotation placed before a table declaration.  When 'relationName' had to
-- change the OVSDB table name (by capitalizing it), an @#[original = "..."]@
-- annotation records the original name; otherwise nothing is emitted.
tableAnnotation :: Table -> Doc
tableAnnotation t
    | relationName t == pp (name t) = empty
    | otherwise                      = text $ "#[original = \"" ++ name t ++ "\"]"

-- | Name of the DDlog relation of kind @tkind@ generated for table @t@: the
-- capitalized table name ('relationName') with a kind-specific prefix.
mkTableName :: (?schema::OVSDBSchema, ?config::Config) => Table -> TableKind -> Doc
mkTableName t tkind = kindPrefix <> relationName t
    where
    -- 'empty' is the unit of '<>', so input relations use the bare name.
    kindPrefix = case tkind of
                      TableInput        -> empty
                      TableOutput       -> "Out_"
                      TableDirectOutput -> "Out_"
                      TableDeltaPlus    -> "DeltaPlus_"
                      TableDeltaMinus   -> "DeltaMinus_"
                      TableDeltaUpdate  -> "Update_"

-- | Read and parse the OVSDB schema in @fname@ and compile it to DDlog.
-- Parse and compilation failures are reported through 'compilerError'.
compileSchemaFile :: FilePath -> Config -> IO Doc
compileSchemaFile fname config = do
    content <- readFile fname
    schema  <- either (compilerError . ("failed to parse input file: " ++)) return
                      (parseSchema content fname)
    either compilerError return $ compileSchema schema config

-- | Compile a parsed OVSDB schema to a DDlog program.
--
-- First validates the configuration: every table named in 'internedTables',
-- 'outputTables', 'outputOnlyTables' and 'multisetTables' must exist in the
-- schema, and table names must be unique.  The result is the 'builtins'
-- preamble followed by three sections: input relations, output relations, and
-- delta tables with their rules.
compileSchema :: (MonadError String me) => OVSDBSchema -> Config -> me Doc
compileSchema schema config = do
    let tables = schemaTables schema
        checkTableExists tname =
            when (isNothing $ find ((== tname) . name) tables)
                 $ throwError $ "Table '" ++ tname ++ "' not found"
    mapM_ checkTableExists $ internedTables config
    mapM_ (checkTableExists . fst) $ outputTables config
    mapM_ checkTableExists $ outputOnlyTables config
    -- Validate multiset tables like the other lists; otherwise a misspelled
    -- name would be silently ignored.
    mapM_ checkTableExists $ multisetTables config
    uniqNames Nothing ("Multiple declarations of table " ++ ) tables
    let ?schema = schema
    let ?config = config
    per_table <- mapM (\t -> mkTable (ovsdbTableKind config t) t) tables
    let (inp, outp, priv) = unzip3 per_table
        rels = vcat $ ["/* Input relations */\n"] ++ inp ++
                      ["\n/* Output relations */\n"] ++ outp ++
                      ["\n/* Delta tables definitions */\n"] ++ priv
    return $ pp builtins $+$ "" $+$ rels

-- | Classify table @t@ according to the configuration.  Membership in
-- 'outputOnlyTables' takes precedence over 'outputTables'; tables in neither
-- list are input tables.
ovsdbTableKind :: Config -> Table -> OVSDBTableKind
ovsdbTableKind config t
    | tname `elem` outputOnlyTables config       = OVSDBTableOutputOnly
    | tname `elem` map fst (outputTables config) = OVSDBTableOutput
    | otherwise                                  = OVSDBTableInput
    where
    tname = name t

-- | Generate the DDlog declarations for one OVSDB table of kind @tkind@.
--
-- Returns (input-section, output-section, delta-section) documents, which
-- 'compileSchema' emits under separate headers.  Throws if the table's
-- "columns" property is missing or duplicated, if its output restriction names
-- an unknown column, or if a column is declared twice.
mkTable :: (?schema::OVSDBSchema, ?config::Config, MonadError String me) => OVSDBTableKind -> Table  -> me (Doc, Doc, Doc)
mkTable tkind t@Table{..} = do
    ovscols <- tableCheckCols t
    let known_cols = map name $ tableGetCols t
    case lookup (name t) (outputTables ?config) of
         Nothing           -> return ()
         Just restrictions ->
             mapM_ (\c -> when (notElem c known_cols) $ throwError $ "Column " ++ c ++ " not found in table " ++ name t)
                   (either id id restrictions)
    uniqNames Nothing (\col -> "Multiple declarations of column " ++ col ++ " in table " ++ tableName) ovscols
    case tkind of
         OVSDBTableInput      -> (, empty, empty) <$> mkTable' TableInput t
         OVSDBTableOutputOnly -> (empty, , empty) <$> mkTable' TableDirectOutput t
         OVSDBTableOutput     -> do
             output       <- mkTable' TableOutput      t
             input        <- mkTable' TableInput       t
             delta_plus   <- mkTable' TableDeltaPlus   t
             delta_minus  <- mkTable' TableDeltaMinus  t
             delta_update <- mkTable' TableDeltaUpdate t
             -- Values from non-root tables are deleted automatically by OVSDB,
             -- so the cost of computing delta-minus is only paid for root tables.
             let minus_section = if tableIsRoot t
                                    then delta_minus $+$ mkDeltaMinusRules t
                                    else empty
             return ( empty
                    , output
                    , (delta_plus $+$ mkDeltaPlusRules t)   $+$
                      minus_section                         $+$
                      delta_update $+$ mkDeltaUpdateRules t $+$
                      input )

-- | Declare the DDlog relation of kind @tkind@ for table @t@.
--
-- Interned input tables are declared as a @typedef@ of the record type plus a
-- relation over @Intern<...>@ values; every other relation lists its columns
-- inline.  Input relations use @_uuid@ as their primary key.
mkTable' :: (?schema::OVSDBSchema, ?config::Config, MonadError String me) => TableKind -> Table -> me Doc
mkTable' tkind t@Table{..} = do
    columns <- mapM (mkCol tkind tableName) rel_cols
    return $ if tableIsInterned t && tkind == TableInput
                then internedDecl columns
                else plainDecl columns
    where
    tname      = mkTableName t tkind
    annotation = tableAnnotation t
    -- Input relations mirror the whole OVSDB table, DeltaMinus only needs the
    -- UUID, and every other kind carries just the writable columns.
    rel_cols = case tkind of
                    TableInput      -> tableGetCols t
                    TableDeltaMinus -> [uuidCol]
                    _               -> tableGetNonROCols t
    -- TableOutput is internal to the program, so it gets no qualifier.
    qualifier = case tkind of
                     TableInput  -> "input"
                     TableOutput -> empty
                     _           -> "output"
    flavor | tkind == TableDirectOutput && elem (name t) (multisetTables ?config) = "multiset"
           | otherwise                                                           = "relation"
    key | tkind == TableInput = "primary key (x) x._uuid"
        | otherwise           = empty
    internedDecl cs = "typedef" <+> pp tname <+> "=" <+> pp tname <+> "{"                   $$
                      (nest' $ vcommaSep cs)                                                $$
                      "}"                                                                   $$
                      annotation                                                            $$
                      qualifier <+> flavor <+> pp tname <+> "[Intern<" <> pp tname <> ">]" $$
                      key
    plainDecl cs = annotation                                    $$
                   qualifier <+> flavor <+> pp tname <+> "("     $$
                   (nest' $ vcommaSep cs)                        $$
                   ")"                                           $$
                   key

-- | Whether the configuration lists table @t@ in 'internedTables'.
tableIsInterned :: (?config::Config) => Table -> Bool
tableIsInterned t = any (== name t) (internedTables ?config)

-- | Rule populating DeltaPlus: Output records whose UUID is absent from the
-- Input relation.
--
-- DeltaPlus(.c = c, ...) :- Output(.c = c, ...), not Input(._uuid = _uuid).
mkDeltaPlusRules :: (?schema::OVSDBSchema, ?config::Config) => Table -> Doc
mkDeltaPlusRules t =
    mkTableName t TableDeltaPlus <> "(" <> bindings <> ") :-"      $$
    nest' (mkTableName t TableOutput <> "(" <> bindings <> "),")   $$
    nest' ("not " <> deref <> mkTableName t TableInput <> "(._uuid = _uuid).")
    where
    -- Interned input relations hold @Intern<...>@ values (see 'mkTable''),
    -- which are matched through '&'.
    deref    = if tableIsInterned t then "&" else empty
    bindings = commaSep [ "." <> n <+> "=" <+> n | n <- map mkColName $ tableGetNonROCols t ]

-- | Rule populating DeltaMinus: UUIDs present in the Input relation but not in
-- the Output relation.
--
-- DeltaMinus(uuid) :- Input(._uuid = uuid), not Output(._uuid = uuid).
mkDeltaMinusRules :: (?schema::OVSDBSchema, ?config::Config) => Table ->  Doc
mkDeltaMinusRules t =
    mkTableName t TableDeltaMinus <> "(uuid) :-" $$
    nest' (input_atom <> ",")                    $$
    nest' ("not" <+> output_atom <> ".")
    where
    deref       = if tableIsInterned t then "&" else empty
    input_atom  = deref <> mkTableName t TableInput <> "(._uuid = uuid)"
    output_atom = mkTableName t TableOutput <> "(._uuid = uuid)"

-- | Rule populating DeltaUpdate: records whose UUID appears in both Output and
-- Input with different values in the writable columns.
--
-- DeltaUpdate(uuid, new) :- Output(uuid, new), Input(uuid, old), old != new.
mkDeltaUpdateRules :: (?schema::OVSDBSchema, ?config::Config) => Table -> Doc
mkDeltaUpdateRules t =
    mkTableName t TableDeltaUpdate <> "(" <> commaSep (map (bindTo "__new_") cols) <> ") :-"           $$
    nest' (mkTableName t TableOutput <> "(" <> commaSep (map (bindTo "__new_") cols) <> "),")           $$
    nest' (deref <> mkTableName t TableInput <> "(" <> commaSep (map (bindTo "__old_") cols) <> "),")   $$
    nest' (parens (commaSep $ map (varFor "__old_") cols) <+> "!=" <+> parens (commaSep $ map (varFor "__new_") cols) <> ".")
    where
    deref = if tableIsInterned t then "&" else empty
    cols  = map mkColName $ tableGetNonROCols t
    -- '_uuid' is the join key shared by both atoms; every other column gets a
    -- prefixed variable so old and new values can be compared.
    varFor pfx n = if n == "_uuid" then n else pfx <> n
    bindTo pfx n = "." <> n <+> "=" <+> varFor pfx n

-- | Declare a single DDlog column @name: type@ for column @c@ of table @tname@,
-- translating its OVSDB type for a relation of kind @tkind@.
--
-- Throws if the column name would collide with an escaped DDlog reserved name.
mkCol :: (?schema::OVSDBSchema, ?config::Config, MonadError String me) => TableKind -> String -> TableColumn -> me Doc
mkCol tkind tname c@TableColumn{..} = do
    -- 'mkColName' lower-cases column names and escapes reserved names with a
    -- "__" prefix, so compare the lower-cased name: a column such as "__X" or
    -- "__x" (where "x" is reserved) would otherwise collide with the escaped
    -- form of "x".
    checkNoProg (notElem (map toLower columnName) __reservedNames) (pos c) $ "Illegal column name " ++ columnName ++ " in table " ++ tname
    t <- case columnType of
              ColumnTypeAtomic at  -> mkAtomicType at
              ColumnTypeComplex ct -> mkComplexType tkind ct
              ColumnTypeUndefined  -> error "OVSDB.Compile.mkCol: undefined column type"
    return $ mkColName c <> ":" <+> t

-- | Escaped, lower-cased forms of DDlog reserved names, i.e. the names
-- 'mkColName'' produces for reserved words.  'mkCol' rejects OVSDB columns
-- that would clash with one of these.
__reservedNames :: [String]
__reservedNames = [ "__" ++ map toLower rn | rn <- reservedNames ]

-- | DDlog field name for an OVSDB column (see 'mkColName'').
mkColName :: TableColumn -> Doc
mkColName = mkColName' . name

-- | Lower-case an OVSDB column name and, if the result is a reserved name,
-- prefix it with @__@ so that it is usable as a DDlog field name.
mkColName' :: String -> Doc
mkColName' c = pp $ escapeReserved $ map toLower c
    where
    escapeReserved x | x `elem` reservedNames = "__" ++ x
                     | otherwise              = x


-- | DDlog type for an OVSDB atomic type.  Strings become @istring@ when
-- 'internStrings' is set.  Calls 'error' on an undefined atomic type.
mkAtomicType :: (?config::Config, MonadError String me) => AtomicType -> me Doc
mkAtomicType atype =
    case atype of
         IntegerType{}         -> return "integer"
         RealType{}            -> return "double"
         BooleanType{}         -> return "bool"
         StringType{}          -> return $ if internStrings ?config then "istring" else "string"
         UUIDType{}            -> return "uuid"
         UndefinedAtomicType{} -> error "OVSDB.Compile.mkAtomicType: undefined atomic type"

-- | (min, max) cardinality of a complex type.  Missing bounds default to 1, and
-- an unlimited max bound is represented as @maxBound :: Word64@.
complexTypeBounds :: ComplexType -> (Integer, Integer)
complexTypeBounds ComplexType{..} =
    ( fromMaybe 1 minComplexType
    , maybe 1 upperBound maxComplexType )
    where
    upperBound (Some x)  = x
    upperBound Unlimited = fromIntegral (maxBound :: Word64)

-- | DDlog type for an OVSDB complex type, chosen by its cardinality:
--
--   * exactly one element             -> the key type itself
--   * zero or one element             -> @Option<key>@
--   * otherwise, without a value type -> @Set<key>@
--   * otherwise, with a value type    -> @Map<key,value>@
--
-- Throws if the bounds are inconsistent, the min bound is not 0 or 1, or a
-- key-value type has a max bound of 1.
mkComplexType :: (?schema::OVSDBSchema, ?config::Config, MonadError String me) => TableKind -> ComplexType -> me Doc
mkComplexType tkind t@ComplexType{..} = do
    let (lo, hi) = complexTypeBounds t
        check cond msg = checkNoProg cond (pos t) msg
    check (hi >= lo)           "min bound exceeds max bound"
    check (lo == 0 || lo == 1) "min bound must be 0 or 1"
    check (hi > 0)             "max bound must be greater than 0"
    check (hi /= 1 || isNothing valueComplexType)
                               "Cannot handle key-value pairs when max bound is 1"
    key_t <- mkBaseType tkind keyComplexType
    case (lo, hi) of
         (1, 1) -> return key_t
         (0, 1) -> return $ "Option<" <> key_t <> ">"
         _      -> maybe (return $ "Set<" <> key_t <> ">")
                         (\v -> (\val_t -> "Map<" <> key_t <> "," <> val_t <> ">") <$> mkBaseType tkind v)
                         valueComplexType

-- | DDlog type for an OVSDB base type.  References to other tables are
-- represented by their @uuid@; anything else maps to its atomic type.
mkBaseType :: (?schema::OVSDBSchema, ?config::Config, MonadError String me) =>  TableKind -> BaseType -> me Doc
mkBaseType _ btype =
    case btype of
         BaseTypeSimple at   -> mkAtomicType at
         BaseTypeComplex cbt -> maybe (mkAtomicType $ typeBaseType cbt)
                                      (const $ return "uuid")
                                      (refTableBaseType cbt)
         BaseTypeUndefined   -> error "OVSDB.Compile.mkBaseType: undefined base type"

-- | Return the columns declared in the table's single "columns" property,
-- throwing if the property is missing or appears more than once.
tableCheckCols :: (MonadError String me) => Table -> me [TableColumn]
tableCheckCols t@Table{..} = do
    let colsets = [cols | ColumnsProperty cols <- tableProperties]
    checkNoProg (not $ null colsets) (pos t) $ "Table " ++ tableName ++ " does not have a \"columns\" property"
    checkNoProg (length colsets == 1) (pos t) $ "Table " ++ tableName ++ " has multiple \"columns\" properties"
    case colsets of
         cols : _ -> return cols
         []       -> error "OVSDB.Compile.tableCheckCols: unreachable"

-- | All columns of a table: the implicit '_uuid' column followed by the columns
-- of its first "columns" property.
--
-- Callers are expected to have validated the table with 'tableCheckCols'; if
-- the property is missing anyway, fail with a descriptive 'error' instead of an
-- anonymous pattern-match failure.
tableGetCols :: Table -> [TableColumn]
tableGetCols Table{..} =
    case [cols | ColumnsProperty cols <- tableProperties] of
         cols : _ -> uuidCol : cols
         []       -> error $ "OVSDB.Compile.tableGetCols: table " ++ tableName ++ " does not have a \"columns\" property"

-- | The implicit @_uuid@ column (atomic UUID type, no source position), which is
-- prepended to every table's columns and used as the primary key of input
-- relations.
uuidCol :: TableColumn
uuidCol = TableColumn { columnName      = "_uuid"
                      , columnType      = ColumnTypeAtomic $ UUIDType nopos
                      , columnMutable   = Nothing
                      , columnEphemeral = Nothing
                      , columnPos       = nopos
                      }

-- | Columns of table @t@ that DDlog writes, according to its 'OutputRelConfig':
-- all columns when unrestricted, all but the read-only ones for @Left ro@, or
-- only the listed ones for @Right rw@.
--
-- @_uuid@ is always kept, in both restricted forms, because it is the primary
-- key and the join column of the generated delta rules.  Previously it could
-- be dropped by listing it in @Left ro@.
tableGetNonROCols :: (?config::Config) => Table -> [TableColumn]
tableGetNonROCols t =
    case lookup (name t) $ outputTables ?config of
         Nothing         -> ovscols
         Just (Left ro)  -> filter (\col -> isUUID col || notElem (name col) ro) ovscols
         Just (Right rw) -> filter (\col -> isUUID col || elem (name col) rw) ovscols
    where
    ovscols = tableGetCols t
    isUUID col = name col == "_uuid"

-- | Whether the table carries a @RootProperty True@ property.  Only root tables
-- get a DeltaMinus relation (see 'mkTable').
tableIsRoot :: Table -> Bool
tableIsRoot Table{..} = or [is_root | RootProperty is_root <- tableProperties]
